//	Maze4DGPUFunctions.metal
//
//	© 2025 by Jeff Weeks
//	See TermsOfUse.txt

#include <metal_stdlib>
using namespace metal;
#include "Maze4DGPUDefinitions.h"	//	Include this *after* "using namespace metal"
									//		if you might need to use uint32_t in "…GPUDefinitions.h"

//	As of June 2023, most iPhones, iPads and Macs support wide color,
//	so we receive all colors as linear Display P3 and then convert them
//	to linear Extended-Range sRGB, as Metal requires when running apps
//	on iOS and iPadOS (and presumably also when running "designed for iPad" apps
//	on Mac and Vision Pro headsets).
//
//	Design notes:
//
//		For "rainbow edges" it's essential that the color conversion
//		happen in the GPU fragment function Maze4DFragmentFunctionHues().
//
//		For "single-color edges" we could, if we wanted to squeeze out
//		every last bit of performance, let the CPU pre-compute
//		the required XR sRGB colors.  But that small performance boost
//		isn't needed, so I decided to do the color conversion
//		here in the GPU vertex function Maze4DVertexFunctionRGB(),
//		for simplicity and especially for consistency with how
//		we handle the rainbow edges' color conversion
//		(the only difference being that for single-color edges
//		the vertex shader computes the color, while
//		for rainbow edges the fragment shader computes the color).
//
//	Program-scope switch read by Maze4DVertexFunctionRGB()
//	and Maze4DFragmentFunctionHues():  when true, they interpret raw colors
//	as linear Display P3 and multiply them by gP3toXRsRGB;
//	when false, they use the raw colors unchanged, as linear sRGB.
constant bool	gUseWideColor = true;

//	Matrix taking linear Display P3 to linear Extended-Range sRGB.
//	The shaders apply it as gP3toXRsRGB * color.
//	Each inner brace group below is one column of the matrix
//	(Metal builds matrices column by column), and each row
//	of the resulting matrix sums to 1 up to rounding,
//	so equal (r, g, b) components stay equal.
constant half3x3	gP3toXRsRGB =
					{
						//	Yes, I know, the following values contain
						//	way too many "significant" digits to fit
						//	into a half-precision float, but I'm leaving them
						//	there anyhow, for future reference.
						//	They're harmless.
						{ 1.2249401762805587, -0.0420569547096881, -0.0196375545903344},
						{-0.2249401762805597,  1.0420569547096874, -0.0786360455506319},
						{ 0.0000000000000001,  0.0000000000000000,  1.0982736001409661}
					};

//	RGB

//	Per-vertex input that Maze4DVertexFunctionRGB() receives via [[ stage_in ]].
struct VertexInputRGB
{
	//	position (x,y,z), in maze coordinates [0, aMazeSize - 1]³
	float3	pos	[[ attribute(VertexAttributePosition) ]];

	//	normal vector (nx, ny, nz)
	half3	nor	[[ attribute(VertexAttributeNormal) ]];
};

//	Output of Maze4DVertexFunctionRGB().
struct VertexOutputRGB
{
	//	view-projection matrix applied to the world-space position
	float4	position	[[ position ]];

	//	final vertex color (r, g, b, 1), already shaded, highlighted and fogged
	half4	color		[[ user(color) ]];
};
//	Input of Maze4DFragmentFunctionRGB().  The user(color) attribute
//	matches the field of the same name in VertexOutputRGB.
struct FragmentInputRGB
{
	//	interpolated color (r, g, b, 1)
	half4	color	[[ user(color) ]];
};

//	Vertex function for single-color geometry.
//
//	Transforms the vertex into clip space and computes its final color:
//	the instance's raw color is (optionally) converted from linear Display P3
//	to linear Extended-Range sRGB, scaled by a diffuse factor,
//	brightened by a specular highlight, and finally darkened by fog.
//
//		in				per-vertex position and normal
//		uniformData		view-projection matrix and the diffuse, specular and fog evaluators
//		instanceData	per-instance model matrix (itsModelMatrix) and raw color (itsRGB)
//		iid				index of the current instance within instanceData
//
//	Returns the clip-space position along with the color (r, g, b, 1).
//
vertex VertexOutputRGB Maze4DVertexFunctionRGB(
	      VertexInputRGB				in				[[ stage_in							]],
	      constant Maze4DUniformData	&uniformData	[[ buffer(BufferIndexUniforms)		]],
	const device Maze4DInstanceDataRGB	*instanceData	[[ buffer(BufferIndexInstanceData)	]],
	      ushort						iid				[[ instance_id						]]	)
{
	float4			tmpWorldPosition;
	VertexOutputRGB	out;
	half3			tmpNormal;
	half			tmpDiffuseFactor,
					tmpSpecularFactor,
					tmpFogFactor;	//	0.0 = fully fogged;  1.0 = no fog
	half3			tmpRawColor,
					tmpXRsRGBColor,
					tmpShadedColor,
					tmpHighlightedColor,
					tmpFoggedColor;
	
	tmpWorldPosition	= instanceData[iid].itsModelMatrix * float4(in.pos, 1.0);
	out.position		= uniformData.itsViewProjectionMatrix * tmpWorldPosition;

	//	Note that it's OK to apply itsModelMatrix to the normal vector here,
	//	because itsModelMatrix does no dilation except for compressing rainbow edges
	//	in the direction orthogonal to the normal vectors.
	//
	//	Do the multiplication in full float precision and narrow only
	//	the result to half, exactly as Maze4DVertexFunctionHues() does.
	//	Narrowing the whole matrix to half4x4 first would round
	//	its rotational part to a 10-bit mantissa, and would also narrow
	//	the translation column, even though the normal's zero w-component
	//	is meant to ignore that column.  A translation too large for half
	//	would become infinite, and infinity times zero is NaN.
	//
	tmpNormal			= half3( ( instanceData[iid].itsModelMatrix * float4(float3(in.nor), 0.0) ).xyz );
	tmpDiffuseFactor	= 0.50h + 0.50h * max(0.0h, dot(uniformData.itsDiffuseEvaluator, tmpNormal));
	tmpSpecularFactor	= 0.25h * pow(max(0.0h, dot(uniformData.itsSpecularEvaluator, tmpNormal) ), 16.0h);

	//	The clamp keeps the fog factor in the range [0.75, 1.00].
	tmpFogFactor		= 1.00h - 0.25h * clamp(dot(uniformData.itsFogEvaluator, half4(tmpWorldPosition)), 0.0h, 1.0h);

	tmpRawColor			= instanceData[iid].itsRGB;
	if (gUseWideColor)
	{
		//	Interpret tmpRawColor as linear Display P3
		//	and convert it to linear Extended-Range sRGB.
		tmpXRsRGBColor	= gP3toXRsRGB * tmpRawColor;
	}
	else
	{
		//	Interpret tmpRawColor as linear sRGB.
		tmpXRsRGBColor	= tmpRawColor;
	}

	//	The diffuse factor scales the color, the specular factor adds
	//	the same amount to all three channels, and the fog factor
	//	scales the sum.
	tmpShadedColor		= tmpDiffuseFactor * tmpXRsRGBColor;
	tmpHighlightedColor	= tmpShadedColor
						+ half3(tmpSpecularFactor, tmpSpecularFactor, tmpSpecularFactor);
	tmpFoggedColor		= tmpFogFactor * tmpHighlightedColor;
	out.color			= half4(tmpFoggedColor, 1.0h);

	return out;
}

//	Fragment function for single-color geometry.
//	Passes the interpolated color through unchanged.
fragment half4 Maze4DFragmentFunctionRGB(
	FragmentInputRGB	in	[[ stage_in ]])
{
	const half4	theInterpolatedColor = in.color;

	return theInterpolatedColor;
}


//	Hues

//	Per-vertex input that Maze4DVertexFunctionHues() receives via [[ stage_in ]].
struct VertexInputHues
{
	//	position (x,y,z), in maze coordinates [0, aMazeSize - 1]³
	float3	pos	[[ attribute(VertexAttributePosition) ]];

	//	normal vector (nx, ny, nz)
	half3	nor	[[ attribute(VertexAttributeNormal) ]];

	//	weight ∈ {0.0, 1.0} used to interpolate between itsHues[0] and itsHues[1]
	half	wgt	[[ attribute(VertexAttributeWeight) ]];
};

//	Output of Maze4DVertexFunctionHues().  The color itself is not computed
//	here;  only the hue and the three lighting factors are passed along.
struct VertexOutputHues
{
	float4	position		[[ position			]];
	half	hue				[[ user(hue)		]];
	half	diffuseFactor	[[ user(diffuse)	]];
	half	specularFactor	[[ user(specular)	]];
	half	fogFactor		[[ user(fog)		]];	//	0.0 = fully fogged;  1.0 = no fog
};
//	Input of Maze4DFragmentFunctionHues().  The user(…) attributes
//	match the like-named fields of VertexOutputHues.
struct FragmentInputHues
{
	half	hue				[[ user(hue)		]];
	half	diffuseFactor	[[ user(diffuse)	]];
	half	specularFactor	[[ user(specular)	]];
	half	fogFactor		[[ user(fog)		]];	//	0.0 = fully fogged;  1.0 = no fog
};

//	Vertex function for rainbow ("hues") geometry.
//
//	Transforms the vertex into clip space, blends the instance's two hues
//	according to the vertex's weight, and computes the diffuse, specular
//	and fog factors.  The hue is left as a hue, so that the fragment function
//	can convert it to a color.
//
//		in				per-vertex position, normal and weight
//		uniformData		view-projection matrix and the diffuse, specular and fog evaluators
//		instanceData	per-instance model matrix (itsModelMatrix) and hue pair (itsHues)
//		iid				index of the current instance within instanceData
//
vertex VertexOutputHues Maze4DVertexFunctionHues(
	      VertexInputHues				in				[[ stage_in							]],
	      constant Maze4DUniformData	&uniformData	[[ buffer(BufferIndexUniforms)		]],
	const device Maze4DInstanceDataHues	*instanceData	[[ buffer(BufferIndexInstanceData)	]],
	      ushort						iid				[[ instance_id						]]	)
{
	const device Maze4DInstanceDataHues	&theInstance = instanceData[iid];
	VertexOutputHues					theResult;

	//	Position:  maze coordinates → world coordinates → clip coordinates
	const float4	theWorldPosition = theInstance.itsModelMatrix * float4(in.pos, 1.0);
	theResult.position = uniformData.itsViewProjectionMatrix * theWorldPosition;

	//	Hue:  a weight of 1.0 selects itsHues[0], a weight of 0.0 selects itsHues[1].
	theResult.hue =       in.wgt      * theInstance.itsHues[0]
					+ (1.0h - in.wgt) * theInstance.itsHues[1];

	//	Normal:  applying itsModelMatrix directly to the normal vector is OK,
	//	because itsModelMatrix does no dilation except for compressing
	//	rainbow edges in the direction orthogonal to the normal vectors.
	const float4	theWorldNormal	= theInstance.itsModelMatrix * float4(float3(in.nor), 0.0);
	const half3		theNormal		= half3(theWorldNormal.xyz);

	//	Lighting
	const half	theDiffuseDot	= max(0.0h, dot(uniformData.itsDiffuseEvaluator,  theNormal));
	const half	theSpecularDot	= max(0.0h, dot(uniformData.itsSpecularEvaluator, theNormal));
	theResult.diffuseFactor		= 0.50h + 0.50h * theDiffuseDot;
	theResult.specularFactor	= 0.25h * pow(theSpecularDot, 16.0h);

	//	Fog:  the clamp keeps the factor in the range [0.75, 1.00].
	const half	theFogAmount = clamp(dot(uniformData.itsFogEvaluator, half4(theWorldPosition)), 0.0h, 1.0h);
	theResult.fogFactor = 1.00h - 0.25h * theFogAmount;

	return theResult;
}

//	Fragment function for rainbow ("hues") geometry.
//
//	Converts the interpolated hue to a fully saturated, full-value color,
//	(optionally) converts that color from linear Display P3
//	to linear Extended-Range sRGB, and then applies the interpolated
//	diffuse, specular and fog factors.
//
fragment half4 Maze4DFragmentFunctionHues(
	FragmentInputHues	in	[[ stage_in ]])
{
	//	The HSV-to-RGB conversion follows the algorithm given at
	//
	//		https://lolengine.net/blog/2013/07/27/rgb-to-hsv-in-glsl
	//
	//	The idea seems to be this.  Consider one color component,
	//	say the red component R, as a function of the hue H.
	//	Its graph is a trapezoidal dip:
	//
	//		___            ___
	//		   \          /
	//		    \        /
	//		     \______/
	//
	//	The function abs(6*H - 3) is a sawtooth:
	//
	//		\                /
	//		 \              /
	//		  \............/
	//		   \          /
	//		    \        /
	//		     \....../
	//		      \    /
	//		       \  /
	//		        \/
	//
	//	Shifting the sawtooth down by 1 unit and clamping the result
	//	to [0, 1] reproduces the graph for R shown in the first figure.
	//	The other two components use the same sawtooth,
	//	with the hue offset by a third or two thirds of a turn.

	const half3	theHueOffsets	= half3(1.00000h, 0.66666h, 0.33333h);
	const half3	theShiftedHues	= fract(half3(in.hue) + theHueOffsets);
	const half3	theSawtooth		= abs(theShiftedHues * 6.0h - half3(3.0h));

	//	A general HSV-to-RGB conversion would compute
	//
	//		value * mix(	half3(1.0h),									//	pure white
	//						clamp(theSawtooth - half3(1.0h), 0.0h, 1.0h),	//	pure saturated color
	//						saturation)
	//
	//	but 4D Maze always uses full value and full saturation,
	//	so only the pure saturated color remains.
	const half3	theRawColor = clamp(theSawtooth - half3(1.0h), 0.0h, 1.0h);

	//	With wide color, theRawColor is linear Display P3 and gets converted
	//	to linear Extended-Range sRGB;  otherwise it is already linear sRGB.
	const half3	theXRsRGBColor = gUseWideColor
								? gP3toXRsRGB * theRawColor
								: theRawColor;

	//	Diffuse shading scales the color, the specular highlight adds
	//	the same amount to all three channels, and the fog scales the sum.
	const half3	theShadedColor		= in.diffuseFactor * theXRsRGBColor;
	const half3	theHighlightedColor	= theShadedColor + half3(in.specularFactor);
	const half3	theFoggedColor		= in.fogFactor * theHighlightedColor;

	return half4(theFoggedColor, 1.0h);
}
